{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ic4_occAAiAT"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:16.155134Z",
     "iopub.status.busy": "2023-11-08T00:06:16.154481Z",
     "iopub.status.idle": "2023-11-08T00:06:16.158680Z",
     "shell.execute_reply": "2023-11-08T00:06:16.158049Z"
    },
    "id": "ioaprt5q5US7"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:16.162165Z",
     "iopub.status.busy": "2023-11-08T00:06:16.161683Z",
     "iopub.status.idle": "2023-11-08T00:06:16.165232Z",
     "shell.execute_reply": "2023-11-08T00:06:16.164580Z"
    },
    "id": "yCl0eTNH5RS3"
   },
   "outputs": [],
   "source": [
    "#@title MIT License\n",
    "#\n",
    "# Copyright (c) 2017 François Chollet\n",
    "#\n",
    "# Permission is hereby granted, free of charge, to any person obtaining a\n",
    "# copy of this software and associated documentation files (the \"Software\"),\n",
    "# to deal in the Software without restriction, including without limitation\n",
    "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
    "# and/or sell copies of the Software, and to permit persons to whom the\n",
    "# Software is furnished to do so, subject to the following conditions:\n",
    "#\n",
    "# The above copyright notice and this permission notice shall be included in\n",
    "# all copies or substantial portions of the Software.\n",
    "#\n",
    "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
    "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
    "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
    "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
    "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
    "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
    "# DEALINGS IN THE SOFTWARE."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ItXfxkxvosLH"
   },
   "source": [
    "# 电影评论文本分类"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Eg62Pmz3o83v"
   },
   "source": [
    "本教程演示了从存储在磁盘上的纯文本文件开始的文本分类。您将训练一个二元分类器对 IMDB 数据集执行情感分析。在笔记本的最后，有一个练习供您尝试，您将在其中训练一个多类分类器来预测 Stack Overflow 上编程问题的标签。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:16.169095Z",
     "iopub.status.busy": "2023-11-08T00:06:16.168538Z",
     "iopub.status.idle": "2023-11-08T00:06:19.034768Z",
     "shell.execute_reply": "2023-11-08T00:06:19.033921Z"
    },
    "id": "8RZOuS9LWQvv"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import re\n",
    "import shutil\n",
    "import string\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import losses\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:19.039547Z",
     "iopub.status.busy": "2023-11-08T00:06:19.038664Z",
     "iopub.status.idle": "2023-11-08T00:06:19.042939Z",
     "shell.execute_reply": "2023-11-08T00:06:19.042282Z"
    },
    "id": "6-tTFS04dChr"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.10.0\n"
     ]
    }
   ],
   "source": [
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NBTI1bi8qdFV"
   },
   "source": [
    "## 情感分析\n",
    "\n",
    "此笔记本训练了一个情感分析模型，利用评论文本将电影评论分类为*正面*或*负面*评价。这是一个*二元*（或二类）分类示例，也是一个重要且应用广泛的机器学习问题。\n",
    "\n",
    "您将使用 [Large Movie Review Dataset](https://ai.stanford.edu/~amaas/data/sentiment/)，其中包含 [Internet Movie Database](https://www.imdb.com/) 中的 50,000 条电影评论文本 。我们将这些评论分为两组，其中 25,000 条用于训练，另外 25,000 条用于测试。训练集和测试集是*均衡的*，也就是说其中包含相等数量的正面评价和负面评价。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iAsKG535pHep"
   },
   "source": [
    "### 下载并探索 IMDB 数据集\n",
    "\n",
    "我们下载并提取数据集，然后浏览一下目录结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:19.046490Z",
     "iopub.status.busy": "2023-11-08T00:06:19.045918Z",
     "iopub.status.idle": "2023-11-08T00:06:59.349596Z",
     "shell.execute_reply": "2023-11-08T00:06:59.348735Z"
    },
    "id": "k7ZYnuajVlFN"
   },
   "outputs": [],
   "source": [
    "url = \"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
    "\n",
    "dataset = tf.keras.utils.get_file(\"aclImdb_v1\", url,\n",
    "                                    untar=True, cache_dir='.',\n",
    "                                    cache_subdir='')\n",
    "\n",
    "dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:59.353724Z",
     "iopub.status.busy": "2023-11-08T00:06:59.353410Z",
     "iopub.status.idle": "2023-11-08T00:06:59.360596Z",
     "shell.execute_reply": "2023-11-08T00:06:59.359981Z"
    },
    "id": "355CfOvsV1pl"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['imdb.vocab', 'imdbEr.txt', 'README', 'test', 'train']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.listdir(dataset_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:59.363856Z",
     "iopub.status.busy": "2023-11-08T00:06:59.363202Z",
     "iopub.status.idle": "2023-11-08T00:06:59.368224Z",
     "shell.execute_reply": "2023-11-08T00:06:59.367563Z"
    },
    "id": "7ASND15oXpF1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['labeledBow.feat',\n",
       " 'neg',\n",
       " 'pos',\n",
       " 'unsup',\n",
       " 'unsupBow.feat',\n",
       " 'urls_neg.txt',\n",
       " 'urls_pos.txt',\n",
       " 'urls_unsup.txt']"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dir = os.path.join(dataset_dir, 'train')\n",
    "os.listdir(train_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ysMNMI1CWDFD"
   },
   "source": [
    "`aclImdb/train/pos` 和 `aclImdb/train/neg` 目录包含许多文本文件，每个文件都是一条电影评论。我们来看看其中的一条评论。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:59.371538Z",
     "iopub.status.busy": "2023-11-08T00:06:59.370977Z",
     "iopub.status.idle": "2023-11-08T00:06:59.375199Z",
     "shell.execute_reply": "2023-11-08T00:06:59.374543Z"
    },
    "id": "R7g8hFvzWLIZ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rachel Griffiths writes and directs this award winning short film. A heartwarming story about coping with grief and cherishing the memory of those we've loved and lost. Although, only 15 minutes long, Griffiths manages to capture so much emotion and truth onto film in the short space of time. Bud Tingwell gives a touching performance as Will, a widower struggling to cope with his wife's death. Will is confronted by the harsh reality of loneliness and helplessness as he proceeds to take care of Ruth's pet cow, Tulip. The film displays the grief and responsibility one feels for those they have loved and lost. Good cinematography, great direction, and superbly acted. It will bring tears to all those who have lost a loved one, and survived.\n"
     ]
    }
   ],
   "source": [
    "sample_file = os.path.join(train_dir, 'pos/1181_9.txt')\n",
    "with open(sample_file) as f:\n",
    "  print(f.read())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Mk20TEm6ZRFP"
   },
   "source": [
    "### 加载数据集\n",
    "\n",
    "接下来，您将从磁盘加载数据并将其准备为适合训练的格式。为此，您将使用有用的 [text_dataset_from_directory](https://tensorflow.google.cn/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory) 实用工具，它期望的目录结构如下所示。\n",
    "\n",
    "```\n",
    "main_directory/\n",
    "...class_a/\n",
    "......a_text_1.txt\n",
    "......a_text_2.txt\n",
    "...class_b/\n",
    "......b_text_1.txt\n",
    "......b_text_2.txt\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nQauv38Lnok3"
   },
   "source": [
    "要准备用于二元分类的数据集，磁盘上需要有两个文件夹，分别对应于 `class_a` 和 `class_b`。这些将是正面和负面的电影评论，可以在 `aclImdb/train/pos` 和 `aclImdb/train/neg` 中找到。由于 IMDB 数据集包含其他文件夹，因此您需要在使用此实用工具之前将其移除。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:06:59.378678Z",
     "iopub.status.busy": "2023-11-08T00:06:59.378111Z",
     "iopub.status.idle": "2023-11-08T00:07:00.250491Z",
     "shell.execute_reply": "2023-11-08T00:07:00.249739Z"
    },
    "id": "VhejsClzaWfl"
   },
   "outputs": [],
   "source": [
    "remove_dir = os.path.join(train_dir, 'unsup')\n",
    "shutil.rmtree(remove_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "95kkUdRoaeMw"
   },
   "source": [
    "接下来，您将使用 `text_dataset_from_directory` 实用工具创建带标签的 `tf.data.Dataset`。[tf.data](https://tensorflow.google.cn/guide/data) 是一组强大的数据处理工具。\n",
    "\n",
    "运行机器学习实验时，最佳做法是将数据集拆成三份：[训练](https://developers.google.com/machine-learning/glossary#training_set)、[验证](https://developers.google.com/machine-learning/glossary#validation_set) 和 [测试](https://developers.google.com/machine-learning/glossary#test-set)。\n",
    "\n",
    "IMDB 数据集已经分成训练集和测试集，但缺少验证集。我们来通过下面的 `validation_split` 参数，使用 80:20 拆分训练数据来创建验证集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:00.255167Z",
     "iopub.status.busy": "2023-11-08T00:07:00.254452Z",
     "iopub.status.idle": "2023-11-08T00:07:04.840458Z",
     "shell.execute_reply": "2023-11-08T00:07:04.839428Z"
    },
    "id": "nOrK-MTYaw3C"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 25000 files belonging to 2 classes.\n",
      "Using 20000 files for training.\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "seed = 42\n",
    "\n",
    "raw_train_ds = tf.keras.utils.text_dataset_from_directory(\n",
    "    'aclImdb/train', \n",
    "    batch_size=batch_size, \n",
    "    validation_split=0.2, \n",
    "    subset='training', \n",
    "    seed=seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5Y33oxOUpYkh"
   },
   "source": [
    "如上所示，训练文件夹中有 25,000 个样本，您将使用其中的 80%（或 20,000 个）进行训练。稍后您将看到，您可以通过将数据集直接传递给 `model.fit` 来训练模型。如果您不熟悉 `tf.data`，还可以遍历数据集并打印出一些样本，如下所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:04.845302Z",
     "iopub.status.busy": "2023-11-08T00:07:04.844561Z",
     "iopub.status.idle": "2023-11-08T00:07:04.880991Z",
     "shell.execute_reply": "2023-11-08T00:07:04.880202Z"
    },
    "id": "51wNaPPApk1K"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Review b\"CAROL'S JOURNEY is a pleasure to watch for so many reasons. The acting of Clara Lago is simply amazing for someone so young, and she is one of those special actors who can say say much with facial expressions. Director Imanol Urbibe presents a tight and controlled film with no break in continuity, thereby propelling the plot at a steady pace with just enough suspense to keep one wondering what the nest scene will bring. The screenplay of Angel Garcia Roldan is story telling at its best, which, it seems, if the major purpose for films after all. The plot is unpredictable, yet the events as they unravel are completely logical. Perhaps the best feature of this film if to tell a story of the Spanish Civil War as it affected the people. It was a major event of the 20th century, yet hardly any Americans know of it. In fact, in 40 years of university teaching, I averaged about one student a semester who had even heard of it, much less any who could say anything comprehensive about it--and the overwhelming number of students were merit scholars, all of which speaks to the enormous amount of censorship in American education. So, in one way, this film is a good way to begin a study of that event, keeping in mind that when one thread is pulled a great deal of history is unraveled. The appreciation of this film is, therefore, in direct relation to the amount of one's knowledge. To view this film as another coming of age movie is the miss the movie completely. The Left Elbow Index considers seven aspects of film-- acting, production sets, character development, plot, dialogue, film continuity, and artistry--on a scale for 10 for very good, 5 for average, and 1 for needs help. CAROL'S JOURNEY is above average on all counts, excepting dialogue which is rated as average. The LEI average for this film is 9.3, raised to a 10 when equated to the IMDb scale. I highly recommend this film for all ages.\"\n",
      "Label 1\n",
      "Review b\"David Mamet is a very interesting and a very un-equal director. His first movie 'House of Games' was the one I liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.<br /><br />So is 'Homicide' which from the title tries to set the mind of the viewer to the usual crime drama. The principal characters are two cops, one Jewish and one Irish who deal with a racially charged area. The murder of an old Jewish shop owner who proves to be an ancient veteran of the Israeli Independence war triggers the Jewish identity in the mind and heart of the Jewish detective.<br /><br />This is were the flaws of the film are the more obvious. The process of awakening is theatrical and hard to believe, the group of Jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. The end of the film itself is Mamet-like smart, but disappoints from a human emotional perspective.<br /><br />Joe Mantegna and William Macy give strong performances, but the flaws of the story are too evident to be easily compensated.\"\n",
      "Label 0\n",
      "Review b'After Harry Reems\\' teenage girlfriend is raped by Zebbedy Colt (The Night-Walker), Reems becomes despondent and consoles himself by having sex with some lesbians. Meanwhile, Colt, who carries a cane and dresses like a magician, rapes some more women. Eventually, Reems decides to track him down and end his crime spree. Despite being shot on film and marginally nasty, it looks like any other 70\\'s porno and is ineptly executed. The rape/abuse scenes are surprisingly restrained and the attempt to cash in on \"Death Wish\" is laughable. R. Bolla (\"Cannibal Holocaust\") plays a cop. Colt, who is usually over-the-top, wigs out in a couple of scenes, but he\\'s too well behaved for my money. This roughie could have been much rougher.'\n",
      "Label 0\n"
     ]
    }
   ],
   "source": [
    "for text_batch, label_batch in raw_train_ds.take(1):\n",
    "  for i in range(3):\n",
    "    print(\"Review\", text_batch.numpy()[i])\n",
    "    print(\"Label\", label_batch.numpy()[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JWq1SUIrp1a-"
   },
   "source": [
    "请注意，评论包含原始文本（带有标点符号和偶尔出现的 HTML 代码，如 `<br/>`）。我们将在以下部分展示如何处理这些问题。\n",
    "\n",
    "标签为 0 或 1。要查看它们与正面和负面电影评论的对应关系，可以查看数据集上的 `class_names` 属性。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:04.884996Z",
     "iopub.status.busy": "2023-11-08T00:07:04.884209Z",
     "iopub.status.idle": "2023-11-08T00:07:04.888550Z",
     "shell.execute_reply": "2023-11-08T00:07:04.887775Z"
    },
    "id": "MlICTG8spyO2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label 0 corresponds to neg\n",
      "Label 1 corresponds to pos\n"
     ]
    }
   ],
   "source": [
    "print(\"Label 0 corresponds to\", raw_train_ds.class_names[0])\n",
    "print(\"Label 1 corresponds to\", raw_train_ds.class_names[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pbdO39vYqdJr"
   },
   "source": [
    "接下来，您将创建验证数据集和测试数据集。您将使用训练集中剩余的 5,000 条评论进行验证。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SzxazN8Hq1pF"
   },
   "source": [
    "注：使用 `validation_split` 和 `subset` 参数时，请确保要么指定随机种子，要么传递 `shuffle=False`，这样验证拆分和训练拆分就不会重叠。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:04.892296Z",
     "iopub.status.busy": "2023-11-08T00:07:04.891639Z",
     "iopub.status.idle": "2023-11-08T00:07:06.118029Z",
     "shell.execute_reply": "2023-11-08T00:07:06.117313Z"
    },
    "id": "JsMwwhOoqjKF"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 25000 files belonging to 2 classes.\n",
      "Using 5000 files for validation.\n"
     ]
    }
   ],
   "source": [
    "raw_val_ds = tf.keras.utils.text_dataset_from_directory(\n",
    "    'aclImdb/train', \n",
    "    batch_size=batch_size, \n",
    "    validation_split=0.2, \n",
    "    subset='validation', \n",
    "    seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:06.121679Z",
     "iopub.status.busy": "2023-11-08T00:07:06.121425Z",
     "iopub.status.idle": "2023-11-08T00:07:07.398188Z",
     "shell.execute_reply": "2023-11-08T00:07:07.397398Z"
    },
    "id": "rdSr0Nt3q_ns"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 25000 files belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "raw_test_ds = tf.keras.utils.text_dataset_from_directory(\n",
    "    'aclImdb/test', \n",
    "    batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qJmTiO0IYAjm"
   },
   "source": [
    "### 准备用于训练的数据集\n",
    "\n",
    "接下来，您将使用有用的 `tf.keras.layers.TextVectorization` 层对数据进行标准化、词例化和向量化。\n",
    "\n",
    "标准化是指对文本进行预处理，通常是移除标点符号或 HTML 元素以简化数据集。词例化是指将字符串分割成词例（例如，通过空格将句子分割成单个单词）。向量化是指将词例转换为数字，以便将它们输入神经网络。所有这些任务都可以通过这个层完成。\n",
    "\n",
    "正如您在上面看到的，评论包含各种 HTML 代码，例如 `<br />`。`TextVectorization` 层（默认情况下会将文本转换为小写并去除标点符号，但不会去除 HTML）中的默认标准化程序不会移除这些代码。您将编写一个自定义标准化函数来移除 HTML。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZVcHl-SLrH-u"
   },
   "source": [
    "注：为了防止[训练-测试偏差](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew)（也称为训练-应用偏差），在训练和测试时间对数据进行相同的预处理非常重要。为此，可以将 `TextVectorization` 层直接包含在模型中，如本教程后面所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:07.402607Z",
     "iopub.status.busy": "2023-11-08T00:07:07.402334Z",
     "iopub.status.idle": "2023-11-08T00:07:07.406673Z",
     "shell.execute_reply": "2023-11-08T00:07:07.406003Z"
    },
    "id": "SDRI_s_tX1Hk"
   },
   "outputs": [],
   "source": [
    "def custom_standardization(input_data):\n",
    "  lowercase = tf.strings.lower(input_data)\n",
    "  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')\n",
    "  return tf.strings.regex_replace(stripped_html,\n",
    "                                  '[%s]' % re.escape(string.punctuation),\n",
    "                                  '')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d2d3Aw8dsUux"
   },
   "source": [
    "<br>接下来，您将创建一个 `TextVectorization` 层。您将使用该层对我们的数据进行标准化、词例化和向量化。您将 `output_mode` 设置为 `int` 以便为每个词例创建唯一的整数索引。\n",
    "\n",
    "请注意，您使用的是默认拆分函数，以及您在上面定义的自定义标准化函数。您还将为模型定义一些常量，例如显式的最大 `sequence_length`，这会使层将序列填充或截断为精确的 `sequence_length` 值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:07.410143Z",
     "iopub.status.busy": "2023-11-08T00:07:07.409843Z",
     "iopub.status.idle": "2023-11-08T00:07:07.425602Z",
     "shell.execute_reply": "2023-11-08T00:07:07.424907Z"
    },
    "id": "-c76RvSzsMnX"
   },
   "outputs": [],
   "source": [
    "max_features = 10000\n",
    "sequence_length = 250\n",
    "\n",
    "vectorize_layer = layers.TextVectorization(\n",
    "    standardize=custom_standardization,\n",
    "    max_tokens=max_features,\n",
    "    output_mode='int',\n",
    "    output_sequence_length=sequence_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vlFOpfF6scT6"
   },
   "source": [
    "接下来，您将调用 `adapt` 以使预处理层的状态适合数据集。这会使模型构建字符串到整数的索引。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lAhdjK7AtroA"
   },
   "source": [
    "注：在调用时请务必仅使用您的训练数据（使用测试集会泄漏信息）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:07.429658Z",
     "iopub.status.busy": "2023-11-08T00:07:07.428989Z",
     "iopub.status.idle": "2023-11-08T00:07:09.966421Z",
     "shell.execute_reply": "2023-11-08T00:07:09.965506Z"
    },
    "id": "GH4_2ZGJsa_X"
   },
   "outputs": [],
   "source": [
    "# Make a text-only dataset (without labels), then call adapt\n",
    "train_text = raw_train_ds.map(lambda x, y: x)\n",
    "vectorize_layer.adapt(train_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SHQVEFzNt-K_"
   },
   "source": [
    "我们来创建一个函数来查看使用该层预处理一些数据的结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:09.971004Z",
     "iopub.status.busy": "2023-11-08T00:07:09.970348Z",
     "iopub.status.idle": "2023-11-08T00:07:09.974465Z",
     "shell.execute_reply": "2023-11-08T00:07:09.973774Z"
    },
    "id": "SCIg_T50wOCU"
   },
   "outputs": [],
   "source": [
    "def vectorize_text(text, label):\n",
    "  text = tf.expand_dims(text, -1)\n",
    "  return vectorize_layer(text), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:09.977759Z",
     "iopub.status.busy": "2023-11-08T00:07:09.977506Z",
     "iopub.status.idle": "2023-11-08T00:07:10.048355Z",
     "shell.execute_reply": "2023-11-08T00:07:10.047585Z"
    },
    "id": "XULcm6B3xQIO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Review tf.Tensor(b\"Hello. I am Paul Raddick, a.k.a. Panic Attack of WTAF, Channel 29 in Philadelphia. Let me tell you about this god awful movie that powered on Adam Sandler's film career but was digitized after a short time.<br /><br />Going Overboard is about an aspiring comedian played by Sandler who gets a job on a cruise ship and fails...or so I thought. Sandler encounters babes that like History of the World Part 1 and Rebound. The babes were supposed to be engaged, but, actually, they get executed by Sawtooth, the meanest cannibal the world has ever known. Adam Sandler fared bad in Going Overboard, but fared better in Big Daddy, Billy Madison, and Jen Leone's favorite, 50 First Dates. Man, Drew Barrymore was one hot chick. Spanglish is red hot, Going Overboard ain't Dooley squat! End of file.\", shape=(), dtype=string)\n",
      "Label neg\n",
      "Vectorized review (<tf.Tensor: shape=(1, 250), dtype=int64, numpy=\n",
      "array([[5128,   10,  237,  700,    1, 2051, 4039, 1235,    5,    1, 1263,\n",
      "           1,    8, 9738,  370,   69,  361,   22,   42,   11,  596,  365,\n",
      "          17,   12,    1,   20, 1909,    1,   19,  608,   18,   13,    1,\n",
      "         101,    4,  350,   58,  168, 6819,    7,   42,   33, 5079, 3052,\n",
      "         248,   32, 3231,   36,  201,    4,  283,   20,    4, 3647, 1684,\n",
      "           3,    1,   37,   10,  197, 3231, 3283, 7272,   12,   38,  468,\n",
      "           5,    2,  182,  170,  473,    3,    1,    2, 7272,   65,  412,\n",
      "           6,   26, 4054,   18,  157,   34,   75, 2202,   32,    1,    2,\n",
      "           1, 5146,    2,  182,   43,  121,  630, 1909, 3231,    1,   80,\n",
      "           8,  168, 6819,   18,    1,  122,    8,  199, 4347, 1518,    1,\n",
      "           3,    1,    1,  500, 1732,   83, 5335,  130, 2161, 5220,   13,\n",
      "          28,  931, 2280,    1,    7,  740,  931,  168, 6819, 2677,    1,\n",
      "           1,  124,    5, 8101,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
      "           0,    0,    0,    0,    0,    0,    0,    0]], dtype=int64)>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)\n"
     ]
    }
   ],
   "source": [
    "# retrieve a batch (of 32 reviews and labels) from the dataset\n",
    "text_batch, label_batch = next(iter(raw_train_ds))\n",
    "first_review, first_label = text_batch[0], label_batch[0]\n",
    "print(\"Review\", first_review)\n",
    "print(\"Label\", raw_train_ds.class_names[first_label])\n",
    "print(\"Vectorized review\", vectorize_text(first_review, first_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6u5EX0hxyNZT"
   },
   "source": [
    "正如您在上面看到的，每个词例都被一个整数替换了。您可以通过在该层上调用 `.get_vocabulary()` 来查找每个整数对应的词例（字符串）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.051939Z",
     "iopub.status.busy": "2023-11-08T00:07:10.051643Z",
     "iopub.status.idle": "2023-11-08T00:07:10.101341Z",
     "shell.execute_reply": "2023-11-08T00:07:10.100607Z"
    },
    "id": "kRq9hTQzhVhW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1287 --->  silent\n",
      " 313 --->  night\n",
      "Vocabulary size: 10000\n"
     ]
    }
   ],
   "source": [
    "print(\"1287 ---> \",vectorize_layer.get_vocabulary()[1287])\n",
    "print(\" 313 ---> \",vectorize_layer.get_vocabulary()[313])\n",
    "print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XD2H6utRydGv"
   },
   "source": [
    "差不多可以训练您的模型了。作为最后的预处理步骤，将之前创建的 TextVectorization 层应用于训练数据集、验证数据集和测试数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.104863Z",
     "iopub.status.busy": "2023-11-08T00:07:10.104588Z",
     "iopub.status.idle": "2023-11-08T00:07:10.240596Z",
     "shell.execute_reply": "2023-11-08T00:07:10.239797Z"
    },
    "id": "2zhmpeViI1iG"
   },
   "outputs": [],
   "source": [
    "train_ds = raw_train_ds.map(vectorize_text)\n",
    "val_ds = raw_val_ds.map(vectorize_text)\n",
    "test_ds = raw_test_ds.map(vectorize_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YsVQyPMizjuO"
   },
   "source": [
    "### 配置数据集以提高性能\n",
    "\n",
    "以下是加载数据时应该使用的两种重要方法，以确保 I/O 不会阻塞。\n",
    "\n",
    "从磁盘加载后，`.cache()` 会将数据保存在内存中。这将确保数据集在训练模型时不会成为瓶颈。如果您的数据集太大而无法放入内存，也可以使用此方法创建高性能的磁盘缓存，这比许多小文件的读取效率更高。\n",
    "\n",
    "`prefetch()` 会在训练时将数据预处理和模型执行重叠。\n",
    "\n",
    "您可以在[数据性能指南](https://tensorflow.google.cn/guide/data_performance)中深入了解这两种方法，以及如何将数据缓存到磁盘。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.244768Z",
     "iopub.status.busy": "2023-11-08T00:07:10.244493Z",
     "iopub.status.idle": "2023-11-08T00:07:10.257482Z",
     "shell.execute_reply": "2023-11-08T00:07:10.256763Z"
    },
    "id": "wMcs_H7izm5m"
   },
   "outputs": [],
   "source": [
    "AUTOTUNE = tf.data.AUTOTUNE\n",
    "\n",
    "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
    "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
    "test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LLC02j2g-llC"
   },
   "source": [
    "### 创建模型\n",
    "\n",
    "是时候创建您的神经网络了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.261022Z",
     "iopub.status.busy": "2023-11-08T00:07:10.260760Z",
     "iopub.status.idle": "2023-11-08T00:07:10.264087Z",
     "shell.execute_reply": "2023-11-08T00:07:10.263392Z"
    },
    "id": "dkQP6in8yUBR"
   },
   "outputs": [],
   "source": [
    "embedding_dim = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.267425Z",
     "iopub.status.busy": "2023-11-08T00:07:10.267153Z",
     "iopub.status.idle": "2023-11-08T00:07:10.336069Z",
     "shell.execute_reply": "2023-11-08T00:07:10.335303Z"
    },
    "id": "xpKOoWgu-llD"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " embedding_1 (Embedding)     (None, None, 16)          160016    \n",
      "                                                                 \n",
      " dropout_2 (Dropout)         (None, None, 16)          0         \n",
      "                                                                 \n",
      " global_average_pooling1d_1   (None, 16)               0         \n",
      " (GlobalAveragePooling1D)                                        \n",
      "                                                                 \n",
      " dropout_3 (Dropout)         (None, 16)                0         \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 1)                 17        \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 160,033\n",
      "Trainable params: 160,033\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential([\n",
    "  layers.Embedding(max_features + 1, embedding_dim),\n",
    "  layers.Dropout(0.2),\n",
    "  layers.GlobalAveragePooling1D(),\n",
    "  layers.Dropout(0.2),\n",
    "  layers.Dense(1)])\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6PbKQ6mucuKL"
   },
   "source": [
    "层按顺序堆叠以构建分类器：\n",
    "\n",
    "1. 第一个层是 `Embedding` 层。此层采用整数编码的评论，并查找每个单词索引的嵌入向量。这些向量是通过模型训练学习到的。向量向输出数组增加了一个维度。得到的维度为：`(batch, sequence, embedding)`。要详细了解嵌入向量，请参阅[单词嵌入向量](https://tensorflow.google.cn/text/guide/word_embeddings)教程。\n",
    "2. 接下来，`GlobalAveragePooling1D` 将通过对序列维度求平均值来为每个样本返回一个定长输出向量。这允许模型以尽可能最简单的方式处理变长输入。\n",
    "3. 该定长输出向量通过一个有 16 个隐层单元的全连接（`Dense`）层传输。\n",
    "4. 最后一层与单个输出结点密集连接。使用 `Sigmoid` 激活函数，其函数值为介于 0 与 1 之间的浮点数，表示概率或置信度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L4EqVWg4-llM"
   },
   "source": [
    "### 损失函数与优化器\n",
    "\n",
    "一个模型需要损失函数和优化器来进行训练。由于这是一个二分类问题且模型输出概率值（一个使用 sigmoid 激活函数的单一单元层），我们将使用 `binary_crossentropy` 损失函数。\n",
    "\n",
    "这不是损失函数的唯一选择，例如，您可以选择 `mean_squared_error` 。但是，一般来说 `binary_crossentropy` 更适合处理概率——它能够度量概率分布之间的“距离”，或者在我们的示例中，指的是度量 ground-truth 分布与预测值之间的“距离”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.346063Z",
     "iopub.status.busy": "2023-11-08T00:07:10.345767Z",
     "iopub.status.idle": "2023-11-08T00:07:10.366284Z",
     "shell.execute_reply": "2023-11-08T00:07:10.365559Z"
    },
    "id": "Mr0GP-cQ-llN"
   },
   "outputs": [],
   "source": [
    "model.compile(loss=losses.BinaryCrossentropy(from_logits=True),\n",
    "              optimizer='adam',\n",
    "              metrics=tf.metrics.BinaryAccuracy(threshold=0.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "35jv_fzP-llU"
   },
   "source": [
    "### 训练模型\n",
    "\n",
    "以 512 个样本的 mini-batch 大小迭代 40 个 epoch 来训练模型。这是指对 `x_train` 和 `y_train` 张量中所有样本的的 40 次迭代。在训练过程中，监测来自验证集的 10,000 个样本上的损失值（loss）和准确率（accuracy）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:07:10.370041Z",
     "iopub.status.busy": "2023-11-08T00:07:10.369560Z",
     "iopub.status.idle": "2023-11-08T00:08:18.123086Z",
     "shell.execute_reply": "2023-11-08T00:08:18.122257Z"
    },
    "id": "tXSGrjWZ-llW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "157/157 [==============================] - 6s 34ms/step - loss: 0.6876 - binary_accuracy: 0.6513 - val_loss: 0.6790 - val_binary_accuracy: 0.7270\n",
      "Epoch 2/10\n",
      "157/157 [==============================] - 1s 7ms/step - loss: 0.6621 - binary_accuracy: 0.7475 - val_loss: 0.6455 - val_binary_accuracy: 0.7544\n",
      "Epoch 3/10\n",
      "157/157 [==============================] - 1s 7ms/step - loss: 0.6206 - binary_accuracy: 0.7712 - val_loss: 0.6027 - val_binary_accuracy: 0.7748\n",
      "Epoch 4/10\n",
      "157/157 [==============================] - 1s 7ms/step - loss: 0.5745 - binary_accuracy: 0.7918 - val_loss: 0.5599 - val_binary_accuracy: 0.7944\n",
      "Epoch 5/10\n",
      "157/157 [==============================] - 1s 7ms/step - loss: 0.5303 - binary_accuracy: 0.8135 - val_loss: 0.5200 - val_binary_accuracy: 0.8138\n",
      "Epoch 6/10\n",
      "157/157 [==============================] - 1s 7ms/step - loss: 0.4896 - binary_accuracy: 0.8310 - val_loss: 0.4848 - val_binary_accuracy: 0.8276\n",
      "Epoch 7/10\n",
      "157/157 [==============================] - 1s 7ms/step - loss: 0.4550 - binary_accuracy: 0.8468 - val_loss: 0.4546 - val_binary_accuracy: 0.8376\n",
      "Epoch 8/10\n",
      "157/157 [==============================] - 2s 10ms/step - loss: 0.4254 - binary_accuracy: 0.8572 - val_loss: 0.4295 - val_binary_accuracy: 0.8464\n",
      "Epoch 9/10\n",
      "157/157 [==============================] - 2s 10ms/step - loss: 0.4001 - binary_accuracy: 0.8650 - val_loss: 0.4087 - val_binary_accuracy: 0.8504\n",
      "Epoch 10/10\n",
      "157/157 [==============================] - 2s 11ms/step - loss: 0.3784 - binary_accuracy: 0.8701 - val_loss: 0.3913 - val_binary_accuracy: 0.8570\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "history = model.fit(\n",
    "    train_ds,\n",
    "    validation_data=val_ds,\n",
    "    epochs=epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9EEGuDVuzb5r"
   },
   "source": [
    "### 评估模型\n",
    "\n",
    "我们来看一下模型的性能如何。将返回两个值。损失值（loss）（一个表示误差的数字，值越低越好）与准确率（accuracy）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:08:18.127359Z",
     "iopub.status.busy": "2023-11-08T00:08:18.126730Z",
     "iopub.status.idle": "2023-11-08T00:08:19.768022Z",
     "shell.execute_reply": "2023-11-08T00:08:19.767150Z"
    },
    "id": "zOMKywn4zReN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "196/196 [==============================] - 5s 25ms/step - loss: 0.4041 - binary_accuracy: 0.8468\n",
      "Loss:  0.40409862995147705\n",
      "Accuracy:  0.8468000292778015\n"
     ]
    }
   ],
   "source": [
    "loss, accuracy = model.evaluate(test_ds)\n",
    "\n",
    "print(\"Loss: \", loss)\n",
    "print(\"Accuracy: \", accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z1iEXVTR0Z2t"
   },
   "source": [
    "这种十分朴素的方法得到了约 87% 的准确率（accuracy）。若采用更好的方法，模型的准确率应当接近 95%。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ldbQqCw2Xc1W"
   },
   "source": [
    "### 创建准确率和损失随时间变化的图表\n",
    "\n",
    "`model.fit()` 会返回包含一个字典的 `History` 对象。该字典包含训练过程中产生的所有信息："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:08:19.771969Z",
     "iopub.status.busy": "2023-11-08T00:08:19.771334Z",
     "iopub.status.idle": "2023-11-08T00:08:19.776308Z",
     "shell.execute_reply": "2023-11-08T00:08:19.775552Z"
    },
    "id": "-YcvZsdvWfDf"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['loss', 'binary_accuracy', 'val_loss', 'val_binary_accuracy'])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history_dict = history.history\n",
    "history_dict.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1_CH32qJXruI"
   },
   "source": [
    "其中有四个条目：每个条目代表训练和验证过程中的一项监测指标。您可以使用这些指标来绘制用于比较的训练损失和验证损失图表，以及训练准确率和验证准确率图表："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:08:19.780029Z",
     "iopub.status.busy": "2023-11-08T00:08:19.779288Z",
     "iopub.status.idle": "2023-11-08T00:08:19.939379Z",
     "shell.execute_reply": "2023-11-08T00:08:19.938685Z"
    },
    "id": "2SEMeQ5YXs8z"
   },
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "acc = history_dict['binary_accuracy']\n",
    "val_acc = history_dict['val_binary_accuracy']\n",
    "loss = history_dict['loss']\n",
    "val_loss = history_dict['val_loss']\n",
    "\n",
    "epochs = range(1, len(acc) + 1)\n",
    "\n",
    "# \"bo\" is for \"blue dot\"\n",
    "plt.plot(epochs, loss, 'bo', label='Training loss')\n",
    "# b is for \"solid blue line\"\n",
    "plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
    "plt.title('Training and validation loss')\n",
    "plt.xlabel('Epochs')\n",
    "plt.ylabel('Loss')\n",
    "plt.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:08:19.943145Z",
     "iopub.status.busy": "2023-11-08T00:08:19.942687Z",
     "iopub.status.idle": "2023-11-08T00:08:20.095342Z",
     "shell.execute_reply": "2023-11-08T00:08:20.094616Z"
    },
    "id": "Z3PJemLPXwz_"
   },
   "outputs": [],
   "source": [
    "plt.plot(epochs, acc, 'bo', label='Training acc')\n",
    "plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
    "plt.title('Training and validation accuracy')\n",
    "plt.xlabel('Epochs')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.legend(loc='lower right')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hFFyCuJoXy7r"
   },
   "source": [
    "在该图表中，虚线代表训练损失和准确率，实线代表验证损失和准确率。\n",
    "\n",
    "请注意，训练损失会逐周期*下降*，而训练准确率则逐周期*上升*。使用梯度下降优化时，这是预期结果，它应该在每次迭代中最大限度减少所需的数量。\n",
    "\n",
    "但是，对于验证损失和准确率来说则不然——它们似乎会在训练转确率之前达到顶点。这是过拟合的一个例子：模型在训练数据上的表现要好于在之前从未见过的数据上的表现。经过这一点之后，模型会过度优化和学习*特定*于训练数据的表示，但无法*泛化*到测试数据。\n",
    "\n",
    "对于这种特殊情况，您可以通过在验证准确率不再增加时直接停止训练来防止过度拟合。一种方式是使用 `tf.keras.callbacks.EarlyStopping` 回调。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-to23J3Vy5d3"
   },
   "source": [
    "## 导出模型\n",
    "\n",
    "在上面的代码中，您在向模型馈送文本之前对数据集应用了 `TextVectorization`。 如果您想让模型能够处理原始字符串（例如，为了简化部署），您可以在模型中包含 `TextVectorization` 层。为此，您可以使用刚刚训练的权重创建一个新模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:08:20.099491Z",
     "iopub.status.busy": "2023-11-08T00:08:20.098808Z",
     "iopub.status.idle": "2023-11-08T00:08:23.463040Z",
     "shell.execute_reply": "2023-11-08T00:08:23.462257Z"
    },
    "id": "FWXsMvryuZuq"
   },
   "outputs": [],
   "source": [
    "export_model = tf.keras.Sequential([\n",
    "  vectorize_layer,\n",
    "  model,\n",
    "  layers.Activation('sigmoid')\n",
    "])\n",
    "\n",
    "export_model.compile(\n",
    "    loss=losses.BinaryCrossentropy(from_logits=False), optimizer=\"adam\", metrics=['accuracy']\n",
    ")\n",
    "\n",
    "# Test it with `raw_test_ds`, which yields raw strings\n",
    "loss, accuracy = export_model.evaluate(raw_test_ds)\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TwQgoN88LoEF"
   },
   "source": [
    "### 使用新数据进行推断\n",
    "\n",
    "要获得对新样本的预测，只需调用 `model.predict()` 即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:08:23.466999Z",
     "iopub.status.busy": "2023-11-08T00:08:23.466420Z",
     "iopub.status.idle": "2023-11-08T00:08:23.696099Z",
     "shell.execute_reply": "2023-11-08T00:08:23.695269Z"
    },
    "id": "QW355HH5L49K"
   },
   "outputs": [],
   "source": [
    "examples = [\n",
    "  \"The movie was great!\",\n",
    "  \"The movie was okay.\",\n",
    "  \"The movie was terrible...\"\n",
    "]\n",
    "\n",
    "export_model.predict(examples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MaxlpFWpzR6c"
   },
   "source": [
    "将文本预处理逻辑包含在模型中后，您可以导出用于生产的模型，从而简化部署并降低[训练/测试偏差](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew)的可能性。\n",
    "\n",
    "在选择应用 TextVectorization 层的位置时，需要注意性能差异。在模型之外使用它可以让您在 GPU 上训练时进行异步 CPU 处理和数据缓冲。因此，如果您在 GPU 上训练模型，您应该在开发模型时使用此选项以获得最佳性能，然后在准备好部署时进行切换，在模型中包含 TextVectorization 层。\n",
    "\n",
    "请参阅此[教程](https://tensorflow.google.cn/tutorials/keras/save_and_load)，详细了解如何保存模型。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eSSuci_6nCEG"
   },
   "source": [
    "## 练习：对 Stack Overflow 问题进行多类分类\n",
    "\n",
    "本教程展示了如何在 IMDB 数据集上从头开始训练二元分类器。作为练习，您可以修改此笔记本以训练多类分类器来预测 [Stack Overflow](http://stackoverflow.com/) 上的编程问题的标签。\n",
    "\n",
    "我们已经准备好了一个[数据集](https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz)供您使用，其中包含了几千个发布在 Stack Overflow 上的编程问题（例如，\"How can sort a dictionary by value in Python?\"）。每一个问题都只有一个标签（Python、CSharp、JavaScript 或 Java）。您的任务是将问题作为输入，并预测适当的标签，在本例中为 Python。\n",
    "\n",
    "您将使用的数据集包含从 [BigQuery](https://console.cloud.google.com/marketplace/details/stack-exchange/stack-overflow) 上更大的公共 Stack Overflow 数据集提取的数千个问题，其中包含超过 1700 万个帖子。\n",
    "\n",
    "下载数据集后，您会发现它与您之前使用的 IMDB 数据集具有相似的目录结构：\n",
    "\n",
    "```\n",
    "train/\n",
    "...python/\n",
    "......0.txt\n",
    "......1.txt\n",
    "...javascript/\n",
    "......0.txt\n",
    "......1.txt\n",
    "...csharp/\n",
    "......0.txt\n",
    "......1.txt\n",
    "...java/\n",
    "......0.txt\n",
    "......1.txt\n",
    "```\n",
    "\n",
    "注：为了增加分类问题的难度，编程问题中出现的 Python、CSharp、JavaScript 或 Java 等词已被替换为 *blank*（因为许多问题都包含它们所涉及的语言）。\n",
    "\n",
    "要完成此练习，您应该对此笔记本进行以下修改以使用 Stack Overflow 数据集：\n",
    "\n",
    "1. 在笔记本顶部，将下载 IMDB 数据集的代码更新为下载前面准备好的 [Stack Overflow 数据集](https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz)的代码。由于 Stack Overflow 数据集具有类似的目录结构，因此您不需要进行太多修改。\n",
    "\n",
    "2. 将模型的最后一层修改为 `Dense(4)`，因为现在有四个输出类。\n",
    "\n",
    "3. 编译模型时，将损失更改为 `tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)`。当每个类的标签是整数（在本例中，它们可以是 0、*1*、*2* 或 *3*）时，这是用于多类分类问题的正确损失函数。 此外，将指标更改为 `metrics=['accuracy']`，因为这是一个多类分类问题（`tf.metrics.BinaryAccuracy` 仅用于二元分类器 ）。\n",
    "\n",
    "4. 在绘制随时间变化的准确率时，请将 `binary_accuracy` 和 `val_binary_accuracy` 分别更改为 `accuracy` 和 `val_accuracy`。\n",
    "\n",
    "5. 完成这些更改后，就可以训练多类分类器了。 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F0T5SIwSm7uc"
   },
   "source": [
    "## 了解更多信息\n",
    "\n",
    "本教程从头开始介绍了文本分类。要详细了解一般的文本分类工作流程，请查看 Google Developers 提供的[文本分类指南](https://developers.google.com/machine-learning/guides/text-classification/)。\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "text_classification.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
